{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torch==2.2 in /opt/conda/lib/python3.10/site-packages (2.2.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2023.12.2)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2) (12.3.101)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch==2.2) (2.1.3)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch==2.2) (1.3.0)\n",
      "Requirement already satisfied: torchvision==0.17 in /opt/conda/lib/python3.10/site-packages (0.17.0)\n",
      "Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17) (1.26.2)\n",
      "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17) (2.28.1)\n",
      "Requirement already satisfied: torch==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17) (2.2.0)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17) (10.2.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (2023.12.2)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2.0->torchvision==0.17) (12.3.101)\n",
      "Requirement already satisfied: charset-normalizer<3,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17) (2.1.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17) (3.3)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17) (1.26.10)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17) (2022.6.15)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch==2.2.0->torchvision==0.17) (2.1.3)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch==2.2.0->torchvision==0.17) (1.3.0)\n",
      "Requirement already satisfied: nltk==3.7 in /opt/conda/lib/python3.10/site-packages (3.7)\n",
      "Requirement already satisfied: click in /opt/conda/lib/python3.10/site-packages (from nltk==3.7) (8.1.7)\n",
      "Requirement already satisfied: joblib in /opt/conda/lib/python3.10/site-packages (from nltk==3.7) (1.3.2)\n",
      "Requirement already satisfied: regex>=2021.8.3 in /opt/conda/lib/python3.10/site-packages (from nltk==3.7) (2022.7.9)\n",
      "Requirement already satisfied: tqdm in /opt/conda/lib/python3.10/site-packages (from nltk==3.7) (4.66.1)\n"
     ]
    }
   ],
   "source": [
    "# Pinned dependencies. %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "%pip install torch==2.2\n",
    "%pip install torchvision==0.17\n",
    "%pip install nltk==3.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class InceptionModule(nn.Module):\n",
    "    \"\"\"Inception block: four parallel branches concatenated along the channel dim.\n",
    "\n",
    "    Every branch preserves the spatial size of its input (1x1 convs, 3x3 convs\n",
    "    with padding=1, and a 3x3 max-pool with stride=1 and padding=1), so the output\n",
    "    has n_channels1x1 + n_channels3x3 + n_channels5x5 + pooling_planes channels.\n",
    "\n",
    "    Args:\n",
    "        input_planes: number of channels of the incoming feature map.\n",
    "        n_channels1x1: output channels of the plain 1x1 branch.\n",
    "        n_channels3x3red: channels of the 1x1 reduction that precedes the 3x3 conv.\n",
    "        n_channels3x3: output channels of the 3x3 branch.\n",
    "        n_channels5x5red: channels of the 1x1 reduction in the third branch.\n",
    "        n_channels5x5: output channels of the third branch.\n",
    "        pooling_planes: output channels of the max-pool branch.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_planes, n_channels1x1, n_channels3x3red, n_channels3x3, n_channels5x5red, n_channels5x5, pooling_planes):\n",
    "        super().__init__()\n",
    "\n",
    "        # Branch 1: 1x1 conv.\n",
    "        self.block1 = nn.Sequential(\n",
    "            nn.Conv2d(input_planes, n_channels1x1, kernel_size=1),\n",
    "            nn.BatchNorm2d(n_channels1x1),\n",
    "            nn.ReLU(inplace=True),\n",
    "        )\n",
    "\n",
    "        # Branch 2: 1x1 reduction -> 3x3 conv.\n",
    "        self.block2 = nn.Sequential(\n",
    "            nn.Conv2d(input_planes, n_channels3x3red, kernel_size=1),\n",
    "            nn.BatchNorm2d(n_channels3x3red),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(n_channels3x3red, n_channels3x3, kernel_size=3, padding=1),\n",
    "            nn.BatchNorm2d(n_channels3x3),\n",
    "            nn.ReLU(inplace=True),\n",
    "        )\n",
    "\n",
    "        # Branch 3: 1x1 reduction -> two stacked 3x3 convs. The stack covers a 5x5\n",
    "        # receptive field, which is why the parameters are named *5x5.\n",
    "        self.block3 = nn.Sequential(\n",
    "            nn.Conv2d(input_planes, n_channels5x5red, kernel_size=1),\n",
    "            nn.BatchNorm2d(n_channels5x5red),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(n_channels5x5red, n_channels5x5, kernel_size=3, padding=1),\n",
    "            nn.BatchNorm2d(n_channels5x5),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(n_channels5x5, n_channels5x5, kernel_size=3, padding=1),\n",
    "            nn.BatchNorm2d(n_channels5x5),\n",
    "            nn.ReLU(inplace=True),\n",
    "        )\n",
    "\n",
    "        # Branch 4: 3x3 max-pool (stride 1, so no downsampling) -> 1x1 conv.\n",
    "        self.block4 = nn.Sequential(\n",
    "            nn.MaxPool2d(3, stride=1, padding=1),\n",
    "            nn.Conv2d(input_planes, pooling_planes, kernel_size=1),\n",
    "            nn.BatchNorm2d(pooling_planes),\n",
    "            nn.ReLU(inplace=True),\n",
    "        )\n",
    "\n",
    "    def forward(self, ip):\n",
    "        \"\"\"Run all four branches on ``ip`` and concatenate the results along dim 1.\"\"\"\n",
    "        branch_outputs = [\n",
    "            self.block1(ip),\n",
    "            self.block2(ip),\n",
    "            self.block3(ip),\n",
    "            self.block4(ip),\n",
    "        ]\n",
    "        # Requires the name torch to be imported (import torch.nn as nn does not bind it).\n",
    "        return torch.cat(branch_outputs, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GoogLeNet(nn.Module):\n",
    "    \"\"\"GoogLeNet-style classifier: conv stem, nine inception modules, linear head.\n",
    "\n",
    "    Layout: stem -> im1, im2 -> max-pool -> im3..im7 -> max-pool -> im8, im9\n",
    "    -> average-pool -> flatten -> fc (1000 outputs).\n",
    "\n",
    "    The head is ``nn.Linear(4096, 1000)`` and im9 emits 1024 channels, so the\n",
    "    average-pooled map must be 2x2. With the two stride-2 max-pools and the\n",
    "    7x7 average-pool this holds for 32x32 inputs (32 -> 16 -> 8 -> 2); other\n",
    "    input sizes will generally fail at ``fc`` with a shape mismatch.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.stem = nn.Sequential(\n",
    "            nn.Conv2d(3, 192, kernel_size=3, padding=1),\n",
    "            nn.BatchNorm2d(192),\n",
    "            nn.ReLU(inplace=True),\n",
    "        )\n",
    "\n",
    "        # Each module's input_planes equals the summed branch outputs of the previous one.\n",
    "        self.im1 = InceptionModule(192,  64,  96, 128, 16, 32, 32)   # -> 256\n",
    "        self.im2 = InceptionModule(256, 128, 128, 192, 32, 96, 64)   # -> 480\n",
    "\n",
    "        # Shared (parameter-free) downsampling layer, applied twice in forward().\n",
    "        self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)\n",
    "\n",
    "        self.im3 = InceptionModule(480, 192,  96, 208, 16,  48,  64)  # -> 512\n",
    "        self.im4 = InceptionModule(512, 160, 112, 224, 24,  64,  64)  # -> 512\n",
    "        self.im5 = InceptionModule(512, 128, 128, 256, 24,  64,  64)  # -> 512\n",
    "        self.im6 = InceptionModule(512, 112, 144, 288, 32,  64,  64)  # -> 528\n",
    "        self.im7 = InceptionModule(528, 256, 160, 320, 32, 128, 128)  # -> 832\n",
    "\n",
    "        self.im8 = InceptionModule(832, 256, 160, 320, 32, 128, 128)  # -> 832\n",
    "        self.im9 = InceptionModule(832, 384, 192, 384, 48, 128, 128)  # -> 1024\n",
    "\n",
    "        self.average_pool = nn.AvgPool2d(7, stride=1)\n",
    "        self.fc = nn.Linear(4096, 1000)  # 4096 = 1024 channels * 2 * 2\n",
    "\n",
    "    def forward(self, ip):\n",
    "        \"\"\"Return the (batch, 1000) output of ``fc`` for an image batch ``ip``.\"\"\"\n",
    "        op = self.stem(ip)\n",
    "\n",
    "        # Stage 1: results must be chained through op, otherwise im1/im2 are bypassed.\n",
    "        op = self.im1(op)\n",
    "        op = self.im2(op)\n",
    "        op = self.max_pool(op)\n",
    "\n",
    "        # Stage 2\n",
    "        op = self.im3(op)\n",
    "        op = self.im4(op)\n",
    "        op = self.im5(op)\n",
    "        op = self.im6(op)\n",
    "        op = self.im7(op)\n",
    "        op = self.max_pool(op)\n",
    "\n",
    "        # Stage 3\n",
    "        op = self.im8(op)\n",
    "        op = self.im9(op)\n",
    "\n",
    "        op = self.average_pool(op)\n",
    "        op = op.view(op.size(0), -1)  # flatten everything except the batch dim\n",
    "        op = self.fc(op)\n",
    "        return op"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (Local)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
